package zone.suplog.rpc.rabbitmq.processor;

import com.rabbitmq.client.*;
import org.springframework.amqp.rabbit.connection.Connection;
import org.springframework.amqp.rabbit.core.RabbitTemplate;
import org.springframework.context.ApplicationContext;
import org.springframework.core.env.Environment;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;
import zone.suplog.rpc.spi.RpcProviderProcessor;
import zone.suplog.rpc.spi.transfer.RpcRequest;
import zone.suplog.rpc.spi.transfer.RpcResponse;

import java.io.*;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Objects;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;

public class RabbitRpcProviderProcessor extends RpcProviderProcessor {

    private ThreadPoolTaskExecutor threadPoolTaskExecutor;

    private final static String EXCHANGE_PROPERTY_NAME = "spring.suplog-rpc.rabbit-mq.provider.exchange";

    private final static String MAX_CONN_TIMEOUT_PROPERTY_NAME = "spring.suplog-rpc.rabbit-mq.provider.max-conn-timeout";

    private final static String REGISTRY_KEY_PROPERTY_NAME = "spring.suplog-rpc.rabbit-mq.provider.registry-key";

    /**
     * 核心线程数
     */
    private final static String CORE_POOL_SIZE_PROPERTY_NAME = "spring.suplog-rpc.rabbit-mq.provider.core-pool-size";

    /**
     * 最大线程数
     */
    private final static String MAX_POOL_SIZE_PROPERTY_NAME = "spring.suplog-rpc.rabbit-mq.provider.max-pool-size";

    /**
     * 队列最大长度 >=mainExecutor.maxSize
     */
    private final static String QUEUE_CAPACITY_PROPERTY_NAME = "spring.suplog-rpc.rabbit-mq.provider.queue-capacity";

    /**
     * 线程池维护线程所允许的空闲时间
     */
    private final static String KEEP_ALIVE_SECONDS_PROPERTY_NAME = "spring.suplog-rpc.rabbit-mq.provider.keep-alive-seconds";

    protected final static String APPLICATION_PROPERTY_NAME = "spring.application.name";

    private Long maxConnTimeout;

    private ApplicationContext applicationContext;

    private Function<Channel, Consumer> createConsumer = channel -> new DefaultConsumer(channel) {
        @Override
        public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) {
            threadPoolTaskExecutor.execute(() -> {
                AMQP.BasicProperties replyProps = new AMQP.BasicProperties.Builder()
                        .correlationId(properties.getCorrelationId())
                        .build();
                RpcResponse rpcResponse = null;
                try {
                    ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(body);
                    ObjectInputStream ois = new ObjectInputStream(byteArrayInputStream);
                    log.debug(Thread.currentThread().getId() + " Accept rpc request: " + envelope.getRoutingKey() + ",channelNumber " + channel.getChannelNumber());
                    RpcRequest rpcRequest = (RpcRequest) ois.readObject();
                    Object bean = applicationContext.getBean(rpcRequest.getTargetInterface(), rpcRequest.getArgs());
                    Method method = bean.getClass().getMethod(rpcRequest.getTargetMethod(), rpcRequest.getParamTypes());
                    Object invoke = method.invoke(bean, rpcRequest.getArgs());
                    log.debug("rpc request: " + rpcRequest);
                    rpcResponse = RpcResponse.success(invoke);
                    log.debug("rpc response: " + rpcResponse);
                } catch (InvocationTargetException e) {
                    log.error(e, e.getTargetException());
                    rpcResponse = RpcResponse.error(e.getTargetException());
                } catch (Exception e) {
                    log.error(e, e.fillInStackTrace());
                    rpcResponse = RpcResponse.error(e);
                } finally {
                    try {
                        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                        ObjectOutputStream oos = new ObjectOutputStream(byteArrayOutputStream);
                        oos.writeObject(rpcResponse);
                        channel.confirmSelect();
                        channel.basicPublish("", properties.getReplyTo(), replyProps, byteArrayOutputStream.toByteArray());
                        channel.waitForConfirmsOrDie(maxConnTimeout);
                    } catch (InterruptedException e) {
                        log.error("信道中断:" + e);
                    } catch (TimeoutException e) {
                        log.error("信道连接超时:" + e);
                    } catch (IOException e) {
                        log.error("IO异常:" + e);
                    } catch (Exception e) {
                        log.error("异常:" + e);
                    }
                }
            });
        }
    };

    @Override
    protected void processInitialization(ApplicationContext applicationContext) {
        Environment environment = applicationContext.getBean(Environment.class);
        threadPoolInitialize(environment);
        String exchange = environment.getProperty(EXCHANGE_PROPERTY_NAME);
        Assert.hasText(exchange, "Property '" + EXCHANGE_PROPERTY_NAME + "' not found!");
        this.maxConnTimeout = environment.getProperty(MAX_CONN_TIMEOUT_PROPERTY_NAME, Long.class, 5000L);
        Assert.notNull(this.maxConnTimeout, "Property '" + MAX_CONN_TIMEOUT_PROPERTY_NAME + "' not found!");
        String applicationName = environment.getProperty(APPLICATION_PROPERTY_NAME);
        Assert.hasText(applicationName, "Property '" + APPLICATION_PROPERTY_NAME + "' not found!");
        Connection connection = applicationContext.getBean(RabbitTemplate.class).getConnectionFactory().createConnection();
        Assert.notNull(connection, "rabbitmq is can't connection!");
        String registryKey = environment.getProperty(REGISTRY_KEY_PROPERTY_NAME);
        Assert.hasText(registryKey, "Property '" + REGISTRY_KEY_PROPERTY_NAME + "' not found!");
        this.applicationContext = applicationContext;
        try {
            Channel channel = Objects.requireNonNull(connection.getDelegate()).createChannel();
            channel.queueDeclare(applicationName, false, false, true, null);
            channel.exchangeDeclare(exchange, BuiltinExchangeType.TOPIC);
            channel.queueBind(applicationName, exchange, registryKey.trim() + ".#");
            channel.basicConsume(applicationName, true, createConsumer.apply(channel));
            log.debug("channel [" + channel.getChannelNumber() + "] is open!");
        } catch (Exception e) {
            log.error(e);
        }
    }

    private void threadPoolInitialize(Environment environment) {
        this.threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
        this.threadPoolTaskExecutor.setCorePoolSize(environment.getProperty(CORE_POOL_SIZE_PROPERTY_NAME, Integer.class, 8));
        this.threadPoolTaskExecutor.setMaxPoolSize(environment.getProperty(MAX_POOL_SIZE_PROPERTY_NAME, Integer.class, 200));
        this.threadPoolTaskExecutor.setQueueCapacity(environment.getProperty(QUEUE_CAPACITY_PROPERTY_NAME, Integer.class, 200));
        this.threadPoolTaskExecutor.setKeepAliveSeconds(environment.getProperty(KEEP_ALIVE_SECONDS_PROPERTY_NAME, Integer.class, 120));
        // 线程池对拒绝任务(无线程可用)的处理策略
        this.threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        this.threadPoolTaskExecutor.initialize();
    }
}
